{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第二章-第六节-autograd\n",
    "本代码是[PyTorch实用教程第二版](https://tingsongyu.github.io/PyTorch-Tutorial-2nd)的配套代码，请结合书中内容学习，效果更佳。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## torch.autograd.backward"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None, inputs=None)\n",
    "\n",
    "* tensors (Sequence[Tensor] or Tensor) – Tensors of which the derivative will be computed.\n",
    "\n",
    "* grad_tensors (Sequence[Tensor or None] or Tensor, optional) – The “vector” in the Jacobian-vector product, usually gradients w.r.t. each element of corresponding tensors. None values can be specified for scalar Tensors or ones that don’t require grad. If a None value would be acceptable for all grad_tensors, then this argument is optional.\n",
    "\n",
    "* retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.\n",
    "\n",
    "* create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.\n",
    "\n",
    "* inputs (Sequence[Tensor] or Tensor, optional) – Inputs w.r.t. which the gradient be will accumulated into .grad. All other Tensors will be ignored. If not provided, the gradient is accumulated into all the leaf Tensors that were used to compute the attr::tensors."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true
   },
   "source": [
    "### retain_grad参数使用\n",
    "对比两个代码段，仔细阅读pytorch报错信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5.])\n",
      "tensor([10.])\n"
     ]
    }
   ],
   "source": [
    "#####  retain_graph=True\n",
    "import torch\n",
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "a = torch.add(w, x)\n",
    "b = torch.add(w, 1)\n",
    "y = torch.mul(a, b)  \n",
    "\n",
    "y.backward(retain_graph=True)\n",
    "print(w.grad)\n",
    "y.backward()\n",
    "print(w.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "运行上面代码段可以看到是正常的，下面这个代码段就会报错，报错信息提示非常明确：Trying to backward through the graph a second time。并且还给出了解决方法： Specify retain_graph=True if you need to backward through the graph a second time 。  \n",
    "这也是pytorch代码写得好的地方，出现错误不要慌，仔细看看报错信息，里边可能会有解决问题的方法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true,
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5.])\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-9-64bccc64184d>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     11\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m \u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     13\u001b[0m \u001b[0mprint\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mw\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda_data\\envs\\pytorch_1.10_gpu\\lib\\site-packages\\torch\\_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    305\u001b[0m                 \u001b[0mcreate_graph\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    306\u001b[0m                 inputs=inputs)\n\u001b[1;32m--> 307\u001b[1;33m         \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    308\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    309\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Anaconda_data\\envs\\pytorch_1.10_gpu\\lib\\site-packages\\torch\\autograd\\__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    154\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[0;32m    155\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 156\u001b[1;33m         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag\n\u001b[0m\u001b[0;32m    157\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    158\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."
     ]
    }
   ],
   "source": [
    "#####  retain_graph=False\n",
    "import torch\n",
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "a = torch.add(w, x)\n",
    "b = torch.add(w, 1)\n",
    "y = torch.mul(a, b)\n",
    "\n",
    "y.backward()\n",
    "print(w.grad)\n",
    "y.backward()\n",
    "print(w.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "### grad_tensors使用\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([9.])\n"
     ]
    }
   ],
   "source": [
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "a = torch.add(w, x)     \n",
    "b = torch.add(w, 1)\n",
    "\n",
    "y0 = torch.mul(a, b)    # y0 = (x+w) * (w+1)    dy0/dw = 2w + x + 1\n",
    "y1 = torch.add(a, b)    # y1 = (x+w) + (w+1)    dy1/dw = 2\n",
    "\n",
    "loss = torch.cat([y0, y1], dim=0)       # [y0, y1]\n",
    "\n",
    "grad_tensors = torch.tensor([1., 2.])\n",
    "\n",
    "loss.backward(gradient=grad_tensors)    # Tensor.backward中的 gradient 传入 torch.autograd.backward()中的grad_tensors\n",
    "\n",
    "# w =  1* (dy0/dw)  +   2*(dy1/dw)\n",
    "# w =  1* (2w + x + 1)  +   2*(w)\n",
    "# w =  1* (5)  +   2*(2)\n",
    "# w =  9\n",
    "\n",
    "print(w.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## torch.autograd.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([6.], grad_fn=<MulBackward0>),)\n",
      "(tensor([2.]),)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "x = torch.tensor([3.], requires_grad=True)\n",
    "y = torch.pow(x, 2)     # y = x**2\n",
    "\n",
    "# 一阶导数\n",
    "grad_1 = torch.autograd.grad(y, x, create_graph=True)   # grad_1 = dy/dx = 2x = 2 * 3 = 6\n",
    "print(grad_1)\n",
    "\n",
    "# 二阶导数\n",
    "grad_2 = torch.autograd.grad(grad_1[0], x)              # grad_2 = d(dy/dx)/dx = d(2x)/dx = 2\n",
    "print(grad_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "## torch.autograd.Function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "有的时候，想要实现自己的一些操作（op），如特殊的数学函数、pytorch的module中没有的网络层，那就需要自己写一个Function，在Function中定义好forward的计算公式、backward的计算公式，然后将这些op组合到模型中，模型就可以用autograd完成梯度求取。\n",
    "\n",
    "这个概念还是很抽象，下面将采用4个实例进行讲解，大家多运行代码体会Function的用处。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true
   },
   "source": [
    "#### 案例1： exp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "案例1：来自 https://pytorch.org/docs/stable/autograd.html#function<br>\n",
    "假设需要一个计算指数的功能，并且能组合到模型中，实现autograd，那么可以这样实现  \n",
    "  \n",
    "第一步：继承Function  \n",
    "第二步：实现forward  \n",
    "第三步：实现backward  \n",
    "  \n",
    "注意事项：\n",
    "1. forward和backward函数第一个参数为**ctx**，它的作用类似于类函数的self一样，更详细解释可参考如下：\n",
    "In the forward pass we receive a Tensor containing the input and return a Tensor containing the output. ctx is a context object that can be used to stash information for backward computation. You can cache arbitrary objects for use in the backward pass using the ctx.save_for_backward method.  \n",
    "\n",
    "2. backward函数返回的参数个数与forward的输入参数个数相同, 即，传入该op的参数，都需要给它们计算对应的梯度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2.7183], grad_fn=<ExpBackward>)\n",
      "tensor([2.7183])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd.function import Function\n",
    "\n",
    "class Exp(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, i):\n",
    "        \n",
    "        # ============== step1: 函数功能实现 ==============\n",
    "        result = i.exp()\n",
    "        # ============== step1: 函数功能实现 ==============\n",
    "        \n",
    "        # ============== step2: 结果保存，用于反向传播 ==============\n",
    "        ctx.save_for_backward(result)\n",
    "        # ============== step2: 结果保存，用于反向传播 ==============\n",
    "        \n",
    "        return result\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        \n",
    "        # ============== step1: 取出结果，用于反向传播 ==============\n",
    "        result, = ctx.saved_tensors\n",
    "        # ============== step1: 取出结果，用于反向传播 ==============\n",
    "        \n",
    "\n",
    "        # ============== step2: 反向传播公式实现 ==============\n",
    "        grad_results = grad_output * result\n",
    "        # ============== step2: 反向传播公式实现 ==============\n",
    "\n",
    "\n",
    "        return grad_results\n",
    "\n",
    "x = torch.tensor([1.], requires_grad=True)  \n",
    "y = Exp.apply(x)                          # 需要使用apply方法调用自定义autograd function\n",
    "print(y)                                  #  y = e^x = e^1 = 2.7183\n",
    "y.backward()                            \n",
    "print(x.grad)                           # 反传梯度,  x.grad = dy/dx = e^x = e^1  = 2.7183\n",
    "\n",
    "# 关于本例子更详细解释，推荐阅读 https://zhuanlan.zhihu.com/p/321449610"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "从代码里可以看到，y这个张量的 **grad_fn** 是 **ExpBackward**，正是我们自己实现的函数，这表明当y求梯度时，会调用**ExpBackward**这个函数进行计算  \n",
    "这也是张量的grad_fn的作用所在"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true
   },
   "source": [
    "#### 案例2：为梯度乘以一定系数 Gradcoeff"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "案例2来自： https://zhuanlan.zhihu.com/p/321449610\n",
    "\n",
    "功能是反向传梯度时乘以一个自定义系数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([4.], grad_fn=<PowBackward0>)\n",
      "tensor([-0.4000])\n"
     ]
    }
   ],
   "source": [
    "class GradCoeff(Function):       \n",
    "       \n",
    "    @staticmethod\n",
    "    def forward(ctx, x, coeff):                 \n",
    "        \n",
    "        # ============== step1: 函数功能实现 ==============\n",
    "        ctx.coeff = coeff   # 将coeff存为ctx的成员变量\n",
    "        x.view_as(x)\n",
    "        # ============== step1: 函数功能实现 ==============\n",
    "        return x\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):            \n",
    "        return ctx.coeff * grad_output, None    # backward的输出个数，应与forward的输入个数相同，此处coeff不需要梯度，因此返回None\n",
    "\n",
    "# 尝试使用\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "ret = GradCoeff.apply(x, -0.1)                  # 前向需要同时提供x及coeff，设置coeff为-0.1\n",
    "ret = ret ** 2                          \n",
    "print(ret)                                      # 注意看： ret.grad_fn \n",
    "ret.backward()  \n",
    "print(x.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "在这里需要注意 backward函数返回的参数个数与forward的输入参数个数相同  \n",
    "即，**传入该op的参数，都需要给它们计算对应的梯度**。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true,
    "hidden": true
   },
   "source": [
    "#### 案例3：勒让德多项式"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "案例来自：https://github.com/excelkks/blog  \n",
    "假设多项式为：$y = a+bx+cx^2+dx^3$时，用两步替代该过程 $y= a+b\\times P_3(c+dx), P_3(x) = \\frac{1}{2}(5x^3-3x)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "127.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import math\n",
    "from torch.autograd.function import Function\n",
    "\n",
    "class LegendrePolynomial3(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        \"\"\"\n",
    "        In the forward pass we receive a Tensor containing the input and return\n",
    "        a Tensor containing the output. ctx is a context object that can be used\n",
    "        to stash information for backward computation. You can cache arbitrary\n",
    "        objects for use in the backward pass using the ctx.save_for_backward method.\n",
    "        \"\"\"\n",
    "        y = 0.5 * (5 * x ** 3 - 3 * x)\n",
    "        ctx.save_for_backward(x)\n",
    "        return y\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        \"\"\"\n",
    "        In the backward pass we receive a Tensor containing the gradient of the loss\n",
    "        with respect to the output, and we need to compute the gradient of the loss\n",
    "        with respect to the input.\n",
    "        \"\"\"\n",
    "        ret, = ctx.saved_tensors\n",
    "        return grad_output * 1.5 * (5 * ret ** 2 - 1)\n",
    "\n",
    "a, b, c, d = 1, 2, 1, 2 \n",
    "x = 1\n",
    "P3 = LegendrePolynomial3.apply\n",
    "y_pred = a + b * P3(c + d * x)\n",
    "print(y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "#### 案例4：手动实现2D卷积"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "案例来自：https://pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html  \n",
    "案例本是卷积与BN的融合实现，此处仅观察Function的使用，更详细的内容，十分推荐阅读原文章   \n",
    "下面看如何实现conv_2d的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "hidden": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd.function import once_differentiable\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def convolution_backward(grad_out, X, weight):\n",
    "    \"\"\"\n",
    "    将反向传播功能用函数包装起来，返回的参数个数与forward接收的参数个数保持一致，为2个\n",
    "    \"\"\"\n",
    "    grad_input = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)\n",
    "    grad_X = F.conv_transpose2d(grad_out, weight)\n",
    "    return grad_X, grad_input\n",
    "\n",
    "class MyConv2D(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, X, weight):\n",
    "        ctx.save_for_backward(X, weight)\n",
    "        \n",
    "        # ============== step1: 函数功能实现 ==============\n",
    "        ret = F.conv2d(X, weight) \n",
    "        # ============== step1: 函数功能实现 ==============\n",
    "        return ret\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_out):\n",
    "        X, weight = ctx.saved_tensors\n",
    "        return convolution_backward(grad_out, X, weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true,
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n",
      "tensor([[[[1.4503, 1.3995, 1.4427],\n",
      "          [1.4725, 1.4247, 1.4995],\n",
      "          [1.4584, 1.4395, 1.5462]],\n",
      "\n",
      "         [[1.5675, 1.5959, 1.5277],\n",
      "          [1.5069, 1.5671, 1.5381],\n",
      "          [1.4938, 1.5488, 1.5162]],\n",
      "\n",
      "         [[1.6349, 1.5952, 1.5464],\n",
      "          [1.6047, 1.5859, 1.5455],\n",
      "          [1.5829, 1.5640, 1.5735]]],\n",
      "\n",
      "\n",
      "        [[[1.2860, 1.2742, 1.3006],\n",
      "          [1.3162, 1.2824, 1.3495],\n",
      "          [1.3139, 1.2992, 1.3743]],\n",
      "\n",
      "         [[1.4297, 1.4514, 1.4030],\n",
      "          [1.3533, 1.3968, 1.3824],\n",
      "          [1.3670, 1.3744, 1.3706]],\n",
      "\n",
      "         [[1.4531, 1.4284, 1.3653],\n",
      "          [1.4276, 1.4299, 1.3891],\n",
      "          [1.4120, 1.4208, 1.4355]]],\n",
      "\n",
      "\n",
      "        [[[1.1845, 1.1387, 1.1964],\n",
      "          [1.1977, 1.1693, 1.2149],\n",
      "          [1.2042, 1.1640, 1.2450]],\n",
      "\n",
      "         [[1.2977, 1.3085, 1.2787],\n",
      "          [1.2251, 1.2535, 1.2254],\n",
      "          [1.2069, 1.2499, 1.2202]],\n",
      "\n",
      "         [[1.2932, 1.2685, 1.2270],\n",
      "          [1.2960, 1.2863, 1.2718],\n",
      "          [1.2611, 1.2894, 1.2867]]],\n",
      "\n",
      "\n",
      "        [[[1.4025, 1.3540, 1.3942],\n",
      "          [1.4003, 1.3713, 1.4468],\n",
      "          [1.4061, 1.4031, 1.4953]],\n",
      "\n",
      "         [[1.5125, 1.5212, 1.4972],\n",
      "          [1.4671, 1.5192, 1.5010],\n",
      "          [1.4668, 1.4805, 1.4737]],\n",
      "\n",
      "         [[1.5746, 1.5602, 1.4785],\n",
      "          [1.5366, 1.5206, 1.4942],\n",
      "          [1.5228, 1.5425, 1.5140]]],\n",
      "\n",
      "\n",
      "        [[[1.3284, 1.2822, 1.3114],\n",
      "          [1.3055, 1.2790, 1.3726],\n",
      "          [1.3167, 1.2998, 1.3996]],\n",
      "\n",
      "         [[1.4537, 1.4507, 1.3772],\n",
      "          [1.3982, 1.4149, 1.3937],\n",
      "          [1.3328, 1.4037, 1.3717]],\n",
      "\n",
      "         [[1.4645, 1.4461, 1.3604],\n",
      "          [1.4523, 1.4556, 1.3755],\n",
      "          [1.4204, 1.4346, 1.4323]]]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)\n",
    "X = torch.rand(10, 3, 7, 7, requires_grad=True, dtype=torch.double)\n",
    "torch.autograd.gradcheck(Conv2D.apply, (X, weight))  # gradcheck 功能请自行了解，通常写完Function会用它检查一下\n",
    "y = Conv2D.apply(X, weight)\n",
    "label = torch.randn_like(y)\n",
    "loss = F.mse_loss(y, label)\n",
    "\n",
    "print(weight.grad)\n",
    "loss.backward()\n",
    "print(weight.grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## autograd相关的知识点"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "autograd使用过程中还有很多需要注意的地方，在这里做个小汇总。\n",
    "* 知识点一：梯度不会自动清零  \n",
    "* 知识点二： 依赖于叶子结点的结点，requires_grad默认为True  \n",
    "* 知识点三： 叶子结点不可执行in-place \n",
    "* 知识点四： detach 的作用\n",
    "* 知识点五： with torch.no_grad()的作用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "####  知识点一：梯度不会自动清零 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([5.])\n",
      "tensor([5.])\n",
      "tensor([5.])\n",
      "tensor([5.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "for i in range(4):\n",
    "    a = torch.add(w, x)\n",
    "    b = torch.add(w, 1)\n",
    "    y = torch.mul(a, b)\n",
    "\n",
    "    y.backward()   \n",
    "    print(w.grad)  # 梯度不会自动清零，数据会累加， 通常需要采用 optimizer.zero_grad() 完成对参数的梯度清零\n",
    "\n",
    "#     w.grad.zero_()     "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### 知识点二：依赖于叶子结点的结点，requires_grad默认为True  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "结点的运算依赖于叶子结点的话，它一定是要计算梯度的，因为叶子结点梯度的计算是从后向前传播的，因此与其相关的结点均需要计算梯度，这点还是很好理解的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True True True\n",
      "False False False\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "w = torch.tensor([1.], requires_grad=True)  # \n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "a = torch.add(w, x)\n",
    "b = torch.add(w, 1)\n",
    "y = torch.mul(a, b)\n",
    "\n",
    "print(a.requires_grad, b.requires_grad, y.requires_grad)\n",
    "print(a.is_leaf, b.is_leaf, y.is_leaf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "#### 知识点三：叶子张量不可以执行in-place操作"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "hidden": true
   },
   "source": [
    "叶子结点不可执行in-place，因为计算图的backward过程都依赖于叶子结点的计算，可以回顾计算图当中的例子，所有的偏微分计算所需要用到的数据都是基于w和x（叶子结点），因此叶子结点不允许in-place操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2361561191752 tensor([1.])\n",
      "2362180999432 tensor([2.])\n",
      "2362180999432 tensor([3.])\n"
     ]
    }
   ],
   "source": [
    "a = torch.ones((1, ))\n",
    "print(id(a), a)\n",
    "\n",
    "a = a + torch.ones((1, ))\n",
    "print(id(a), a)\n",
    "\n",
    "a += torch.ones((1, ))\n",
    "print(id(a), a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "hidden": true
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "a leaf Variable that requires grad is being used in an in-place operation.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-41-7e2ec3c17fc3>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[0my\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmul\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0ma\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mb\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m \u001b[0mw\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      9\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     10\u001b[0m \u001b[0my\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mRuntimeError\u001b[0m: a leaf Variable that requires grad is being used in an in-place operation."
     ]
    }
   ],
   "source": [
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "a = torch.add(w, x)\n",
    "b = torch.add(w, 1)\n",
    "y = torch.mul(a, b)\n",
    "\n",
    "w.add_(1)\n",
    "\n",
    "y.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 知识点四：detach 的作用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过以上知识，我们知道计算图中的张量是不能随便修改的，否则会造成计算图的backward计算错误，那有没有其他方法能修改呢？当然有，那就是detach()  \n",
    "\n",
    "detach的作用是：从计算图中剥离出“数据”，并以一个新张量的形式返回，**并且**新张量与旧张量共享数据，简单的可理解为做了一个别名。\n",
    "请看下例的w，detach后对w_detach修改数据，w同步地被改为了999\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([999.], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "w = torch.tensor([1.], requires_grad=True)\n",
    "x = torch.tensor([2.], requires_grad=True)\n",
    "\n",
    "a = torch.add(w, x)\n",
    "b = torch.add(w, 1)\n",
    "y = torch.mul(a, b)\n",
    "\n",
    "y.backward()\n",
    "\n",
    "w_detach = w.detach()\n",
    "w_detach.data[0] = 999\n",
    "print(w)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 知识点五：with torch.no_grad()的作用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "autograd自动构建计算图过程中会保存一系列中间变量，以便于backward的计算，这就必然需要花费额外的内存和时间。  \n",
    "而并不是所有情况下都需要backward，例如推理的时候，因此可以采用上下文管理器——torch.no_grad()来管理上下文，让pytorch不记录相应的变量，以加快速度和节省空间。  \n",
    "详见：https://pytorch.org/docs/stable/generated/torch.no_grad.html?highlight=no_grad#torch.no_grad"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch_1.10_gpu",
   "language": "python",
   "name": "pytorch_1.10_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
